Project carried out during the Advanced Data Mining course at Poznan University of Technology, Winter Semester 2024/2025.
Course led by Dariusz Brzeziński, PhD, and Witold Taisner, MSc.
knitr::opts_chunk$set(echo = TRUE, warning = FALSE, message = FALSE)
library(ggplot2)
library(tidyr)
library(dplyr)
library(tibble)
library(reshape2)
library(plotly)
library(timechange)
library(ModelMetrics)
library(caret)
set.seed(413)
Energy storage plays a central role today. Batteries power everyday appliances such as phones, clocks and headphones. They are also increasingly present in transport as the basis of electric vehicles, which are becoming more and more popular. This poses a series of requirements that developers need to meet before a battery can be put to use.
The most important features of a battery are its ability to recharge and the time that recharge takes. Capacity is also crucial, but mainly in relation to the mass and volume of the battery.
Researchers try to find new composites and compounds that would allow batteries to be lighter, smaller, more versatile and faster to recharge.
This project aims to summarize some of the features of existing batteries and their distributions, and to find correlations between those attributes, especially with respect to the working ion. The working ion is finally used as the target of a prediction model, which may be used to infer what was inside a battery based on its characteristics.
The most important results are as follows:
Data used in this project come from the Materials Project of the US Department of Energy. The original data is collected to help researchers develop new materials for construction, energy production and storage, and to improve and test the existing ones. It is distributed in open access under a Creative Commons licence.
For the analysis only a fraction of the dataset is used: the part concerning batteries and their parameters.
Even though the dataset was preprocessed and cleaned by its authors, some
cleaning operations are executed to make sure it follows the needed
format. Data is read from a csv file and saved in a cached variable.
Records that lack a value in the Battery.ID or Battery.Formula column are
omitted. In numeric columns, nulls and other missing values are replaced
by 0, and in text columns they are replaced by "?" (a
string consisting of a single question mark). The headers present in the
file are kept as column names, and Battery.ID becomes the row names of the data frame.
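# Read the raw file, drop records without an ID or formula, fill remaining gaps
mp <- read.csv("data/mp_batteries.csv", header = TRUE) %>%
  drop_na(Battery.ID, Battery.Formula) %>%
  mutate_if(is.numeric, ~replace_na(., 0)) %>%
  mutate_if(is.character, ~replace_na(., "?")) %>%
  column_to_rownames(var = "Battery.ID")
# The working ion is the only text field with few enough levels to be a factor
mp$Working.Ion <- as.factor(mp$Working.Ion)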
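The cleaning rules described above can be tried on a tiny, made-up data frame. This is only a sketch: the `raw` table and its values are hypothetical, and the steps are written in base R so that the example runs without the tidyverse. It drops records without an ID or formula, fills numeric gaps with 0 and text gaps with "?", and moves the ID into the row names.

```r
# Hypothetical toy data; the column names mimic the real file, the values do not
raw <- data.frame(
  Battery.ID      = c("b1", NA, "b3"),
  Battery.Formula = c("Li0-1X", "Na0-1Y", "Mg0-1Z"),
  Working.Ion     = c("Li", "Na", NA),
  Average.Voltage = c(3.1, 2.4, NA),
  stringsAsFactors = FALSE
)

# 1. Drop records that lack an ID or a formula
clean <- raw[!is.na(raw$Battery.ID) & !is.na(raw$Battery.Formula), ]

# 2. Replace missing values: 0 in numeric columns, "?" in text columns
num_cols <- vapply(clean, is.numeric, logical(1))
chr_cols <- vapply(clean, is.character, logical(1))
clean[num_cols] <- lapply(clean[num_cols], function(x) replace(x, is.na(x), 0))
clean[chr_cols] <- lapply(clean[chr_cols], function(x) replace(x, is.na(x), "?"))

# 3. Use the ID as row names and remove it from the columns
rownames(clean) <- clean$Battery.ID
clean$Battery.ID <- NULL

print(clean)
```

Note that filling numeric gaps with 0 is a simple convention rather than a neutral choice: if many values were missing, the zeros would pull means and correlations towards zero.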
To check how the data is structured, the first ten rows are shown.
knitr::kable(head(mp, 10))
| | Battery.Formula | Working.Ion | Formula.Charge | Formula.Discharge | Max.Delta.Volume | Average.Voltage | Gravimetric.Capacity | Volumetric.Capacity | Gravimetric.Energy | Volumetric.Energy | Atomic.Fraction.Charge | Atomic.Fraction.Discharge | Stability.Charge | Stability.Discharge | Steps | Max.Voltage.Step |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mp-30_Al | Al0-2Cu | Al | Cu | Al2Cu | 3.0433992 | 0.0890331 | 1368.48055 | 5562.7901 | 121.840086 | 495.272533 | 0.0000000 | 0.6666667 | 0.0000000 | 0.0000000 | 1 | 0 |
| mp-1022721_Al | Al1-3Cu | Al | AlCu | Al3Cu | 1.2436528 | -0.0215863 | 1112.93655 | 4418.9798 | -24.024232 | -95.389622 | 0.5000000 | 0.7500000 | 0.0740612 | 0.0962458 | 1 | 0 |
| mp-8637_Al | Al0-5Mo | Al | Mo | Al5Mo | 4.7625743 | 0.1227568 | 1741.50416 | 7175.7017 | 213.781556 | 880.866507 | 0.0000000 | 0.8333333 | 0.4114601 | 0.0452120 | 1 | 0 |
| mp-129_Al | Al0-12Mo | Al | Mo | Al12Mo | 12.7238931 | 0.0431214 | 2298.81076 | 7346.2323 | 99.128013 | 316.780060 | 0.0000000 | 0.9230769 | 0.0000000 | 0.0114456 | 1 | 0 |
| mp-91_Al | Al0-12W | Al | W | Al12W | 12.4945977 | 0.0292342 | 1900.74513 | 7332.7186 | 55.566774 | 214.366205 | 0.0000000 | 0.9230769 | 0.0000000 | 0.0000000 | 1 | 0 |
| mp-1055908_Al | Al0-12Mn | Al | Mn | MnAl12 | 18.2361563 | 0.0397314 | 2547.69280 | 7592.9161 | 101.223298 | 301.676876 | 0.0000000 | 0.9230769 | 0.1454643 | 0.0000000 | 1 | 0 |
| mp-2658_Al | Al0-1Fe | Al | Fe | AlFe | 0.7711539 | 0.4717287 | 970.75702 | 5622.3562 | 457.933974 | 2652.226958 | 0.0000000 | 0.5000000 | 0.7613994 | 0.0000000 | 1 | 0 |
| mp-16722_Al | Al1-10.25V | Al | Al10V | Al41V4 | 0.0027108 | -0.0155827 | 61.37701 | 176.4151 | -0.956421 | -2.749028 | 0.9090909 | 0.9111111 | 0.0118097 | 0.0125861 | 1 | 0 |
| mp-998981_Al | Al1-3Ti | Al | TiAl | TiAl3 | 0.9562924 | 0.1602450 | 1248.40362 | 4248.4211 | 200.050419 | 680.788169 | 0.5000000 | 0.7500000 | 0.1415912 | 0.0244962 | 1 | 0 |
| mp-8633_K | K0-3Cr | K | Cr | K3Cr | 15.8029363 | -0.7487069 | 474.94813 | 667.5593 | -355.596958 | -499.806269 | 0.0000000 | 0.7500000 | 0.4025263 | 0.6621618 | 1 | 0 |
The dataset consists of 17 columns:
For each text field, the number of unique values is counted and the most frequent value is retrieved. The total number of records in the set is also presented.
most_frequent <- function(x) {
tbl <- table(x)
mode_value <- names(tbl)[which.max(tbl)]
mode_value
}
mp %>% summarise(n())
## n()
## 1 4351
summaryDF <- data.frame(uniqueCount = numeric(), mostFrequent = character(), stringsAsFactors = FALSE)
cols <- mp %>% select(Battery.Formula:Formula.Discharge) %>% colnames()
for (columnName in cols){
row <- mp %>% summarise(
uniqueCount = n_distinct(.data[[columnName]]),
mostFrequent = most_frequent(.data[[columnName]])
)
summaryDF <- rbind(summaryDF, row)
}
rownames(summaryDF) <- cols
summaryDF %>% knitr::kable()
| | uniqueCount | mostFrequent |
|---|---|---|
| Battery.Formula | 3301 | Li0-1V2OF5 |
| Working.Ion | 10 | Li |
| Formula.Charge | 2096 | MnO2 |
| Formula.Discharge | 3173 | LiCoPO4 |
For numeric fields, basic statistics are calculated: mean, median, standard deviation and quartiles, as well as the minimal and maximal value of each field. As seen, most of the fields have outliers, which should be removed for better results.
summaryDF <- data.frame(
mean_value = numeric(), median_value = numeric(), sd_value = numeric(),
'1st. quartile' = numeric(), '3rd. quartile' = numeric(),
min_value = numeric(), max_value = numeric(), stringsAsFactors = FALSE
)
cols <- mp %>% select(Max.Delta.Volume:Stability.Discharge) %>% colnames()
for (columnName in cols){
row <- mp %>% summarise(
mean_value = mean(.data[[columnName]], na.rm = TRUE),
median_value = median(.data[[columnName]], na.rm = T),
sd_value = sd(.data[[columnName]], na.rm = TRUE),
'1st. quartile' = quantile(.data[[columnName]], 0.25),
'3rd. quartile' = quantile(.data[[columnName]], 0.75),
min_value = min(.data[[columnName]]),
max_value = max(.data[[columnName]])
)
summaryDF <- rbind(summaryDF, row)
}
rownames(summaryDF) <- cols
summaryDF %>% knitr::kable()
| | mean_value | median_value | sd_value | 1st. quartile | 3rd. quartile | min_value | max_value |
|---|---|---|---|---|---|---|---|
| Max.Delta.Volume | 0.3753137 | 0.0420271 | 6.8518375 | 0.0174699 | 0.0859531 | 0.0000162 | 2.931932e+02 |
| Average.Voltage | 3.0831427 | 3.3005818 | 1.8220562 | 2.2257406 | 4.0186221 | -7.7547512 | 5.456883e+01 |
| Gravimetric.Capacity | 158.2908894 | 130.6909797 | 164.9136411 | 88.1080335 | 187.6001100 | 5.1765430 | 2.557627e+03 |
| Volumetric.Capacity | 610.6240987 | 507.0312049 | 563.8531258 | 311.6147847 | 722.7548973 | 24.0790699 | 7.619191e+03 |
| Gravimetric.Energy | 444.1063802 | 401.7876573 | 351.0481297 | 211.6921486 | 614.4157432 | -583.5458444 | 5.926950e+03 |
| Volumetric.Energy | 1664.0484137 | 1463.7877150 | 1297.7985678 | 821.6252773 | 2252.2567782 | -2208.0745659 | 1.830590e+04 |
| Atomic.Fraction.Charge | 0.0398558 | 0.0000000 | 0.0885604 | 0.0000000 | 0.0476190 | 0.0000000 | 9.090909e-01 |
| Atomic.Fraction.Discharge | 0.1590772 | 0.1428571 | 0.1203743 | 0.0869565 | 0.2000000 | 0.0074074 | 9.933333e-01 |
| Stability.Charge | 0.1425666 | 0.0731920 | 0.3782776 | 0.0330126 | 0.1316053 | 0.0000000 | 6.487098e+00 |
| Stability.Discharge | 0.1220717 | 0.0487845 | 0.3523182 | 0.0195235 | 0.0929907 | 0.0000000 | 6.277809e+00 |
The first step of the detailed analysis is to determine the outliers in the numeric data. The report presents the upper threshold, calculated as the mean plus three standard deviations, and the number of values considered outliers, i.e. those lying more than three standard deviations from the mean in either direction. Please note that both attributes concerning steps are integers, so their threshold should be rounded up to the nearest integer.
As seen, outliers are present in every attribute. For some attributes they form a fairly numerous group, but never more than one hundred examples.
This knowledge is applied to histograms, generated in next step, to improve their readability.
outlier_thresholds <- mp %>% select_if(is.numeric) %>%
summarise(across(
everything(),
list(
upperOutlierThreshold = ~ mean(.x, na.rm = TRUE) + 3 * sd(.x, na.rm = TRUE)
)
)) %>%
pivot_longer(
cols = everything(),
names_to = "column_name",
values_to = "outliers_threshold"
)
outlier_thresholds$column_name <- sub("_.*", "", outlier_thresholds$column_name)
outlier_thresholds <- outlier_thresholds %>%
column_to_rownames(var='column_name')
outliers_count <- mp %>% select_if(is.numeric) %>%
summarise(across(
everything(),
list(outliers = ~ sum(abs(scale(.x)) > 3, na.rm = TRUE))
)) %>%
pivot_longer(
cols = everything(),
names_to = "column_name",
values_to = "outliers_count"
)
outliers_count$column_name <- sub("_.*", "", outliers_count$column_name)
outliers_count <- outliers_count %>%
column_to_rownames(var='column_name')
summaryDF <- cbind(outlier_thresholds, outliers_count)
summaryDF %>% knitr::kable()
| | outliers_threshold | outliers_count |
|---|---|---|
| Max.Delta.Volume | 20.9308261 | 4 |
| Average.Voltage | 8.5493113 | 18 |
| Gravimetric.Capacity | 653.0318129 | 43 |
| Volumetric.Capacity | 2302.1834762 | 62 |
| Gravimetric.Energy | 1497.2507692 | 40 |
| Volumetric.Energy | 5557.4441170 | 39 |
| Atomic.Fraction.Charge | 0.3055370 | 81 |
| Atomic.Fraction.Discharge | 0.5202002 | 94 |
| Stability.Charge | 1.2773994 | 52 |
| Stability.Discharge | 1.1790262 | 73 |
| Steps | 2.5583368 | 99 |
| Max.Voltage.Step | 2.0404938 | 66 |
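The three-sigma rule used for this table can be checked by hand on a small sample. The vector `x` below is hypothetical and chosen only so that exactly one value stands out; the sketch computes the upper threshold, the z-score based count, and the rounded-up threshold that would apply to an integer attribute.

```r
# Hypothetical sample: nineteen ordinary values and one extreme value
x <- c(rep(10, 19), 100)

# Upper threshold: mean plus three standard deviations
threshold <- mean(x) + 3 * sd(x)

# Count based on absolute z-scores, so both tails are covered
z <- (x - mean(x)) / sd(x)
n_outliers <- sum(abs(z) > 3)

# For integer attributes the threshold is rounded up
threshold_int <- ceiling(threshold)

cat("threshold:", threshold, "\n")
cat("outliers:", n_outliers, "\n")
cat("integer threshold:", threshold_int, "\n")
```

Because the count uses the absolute z-score, it can be larger than the number of values above the upper threshold whenever a field also has extreme values on the low side.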
To gain additional insight into the data, histograms for all the numerical fields are presented. As seen, most of them are unimodal and strongly skewed: the examples concentrate around the median, which is shifted towards one end of the range, leaving a long, sparse tail on the other side. Outliers appear mostly at the upper end of the range.
For the first four fields the default range is applied, covering all the cases. The remaining histograms are clipped to remove very distant outliers that impaired readability.
cols <- mp %>% select(Gravimetric.Capacity:Volumetric.Energy) %>% colnames()
for (columnName in cols) {
graph <- mp %>% ggplot(aes(x = .data[[columnName]])) +
geom_histogram(binwidth = 60, fill = 'blue', color = 'black', alpha = 0.7) +
labs(title = paste("Histogram of", columnName), x = columnName, y = "Frequency") +
theme_minimal()
plot(graph)
}
graph <- mp %>% ggplot(aes(x = Max.Delta.Volume)) +
geom_histogram(binwidth = 0.01, fill = 'blue', color = 'black', alpha = 0.7) +
labs(title = "Histogram of Max.Delta.Volume", x = 'Max.Delta.Volume', y = "Frequency") +
xlim(-0.1, 2) +
theme_minimal()
plot(graph)
cols <- mp %>% select(Atomic.Fraction.Charge:Stability.Discharge) %>% colnames()
for (columnName in cols) {
graph <- mp %>% ggplot(aes(x = .data[[columnName]])) +
geom_histogram(binwidth = 0.1, fill = 'blue', color = 'black', alpha = 0.7) +
xlim(-0.1,1.5)+
labs(title = paste("Histogram of", columnName), x = columnName, y = "Frequency") +
theme_minimal()
plot(graph)
}
Before plotting the relationships, it is useful to check the correlation coefficients between the fields. The result is presented as a matrix, with more saturated colors corresponding to stronger correlation (blue for positive, red for negative). The fields form pairs of highly correlated values (the Charge/Discharge pairs, the two energy fields, the two capacity fields). A fairly good correlation can also be observed between energy and capacity. Less obvious pairs exist between energy and voltage, and between capacity and atomic fraction discharge, but they probably follow from the physics behind the battery.
cor_matrix <- cor(mp %>% select_if(is.numeric), use = "complete.obs", method = "pearson")
cor_matrix_melted <- melt(cor_matrix)
ggplot(data = cor_matrix_melted, aes(x = Var1, y = Var2, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient2(low = "red", high = "blue", mid = "white",
midpoint = 0, limit = c(-1, 1), space = "Lab",
name = "Pearson\nCorrelation") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
coord_fixed() +
labs(title = "Correlation Matrix", x = "Fields", y = "Fields")
The four pairs mentioned above are shown on scatter plots with a trend line included. This leads to interesting results. Average voltage and energy have a high correlation coefficient, and the graph shows it: the points are clearly centered around the trend line. However, after removing outliers (i.e. filtering voltage to values less than 20), the distribution loses its coherence and takes a more conical shape (as shown on the volumetric graph).
On the other hand, atomic fraction discharge plotted against capacity does not behave so well. First and foremost, the distribution is clearly much denser at certain points on the x axis (discharge fraction). The values in these places rise accordingly, but still form a range rather than a single point. Another interesting observation is that the points follow a quadratic rather than a linear trend.
mp %>% ggplot(aes(x = Average.Voltage, y=Gravimetric.Energy)) +
geom_point(color = "blue") + # Scatter points
geom_smooth(method = "lm") + # Regression line
labs(title = "Correlation between Average Voltage and Gravimetric Energy",
x = "Avg. Voltage",
y = "Gravimetric Energy") +
theme_minimal()
mp %>% filter(Average.Voltage < 20) %>% ggplot(aes(x = Average.Voltage, y=Volumetric.Energy)) +
geom_point(color = "blue") + # Scatter points
geom_smooth(method = "lm") + # Regression line
labs(title = "Correlation between Average Voltage and Volumetric Energy",
x = "Avg. Voltage",
y = "Volumetric Energy") +
theme_minimal()
mp %>% ggplot(aes(x = Atomic.Fraction.Discharge, y=Gravimetric.Capacity)) +
geom_point(color = "blue") + # Scatter points
geom_smooth(method = "lm") + # Regression line
labs(title = "Correlation between Atomic Fraction Discharge and Gravimetric Capacity",
x = "Atomic Fraction Discharge",
y = "Gravimetric Capacity") +
theme_minimal()
mp %>% ggplot(aes(x = Atomic.Fraction.Discharge, y=Volumetric.Capacity)) +
geom_point(color = "blue") + # Scatter points
geom_smooth(method = "lm", se=F) + # Regression line
labs(title = "Correlation between Atomic Fraction Discharge and Volumetric Capacity",
x = "Atomic Fraction Discharge",
y = "Volumetric Capacity") +
theme_minimal()
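The observation that a correlation can lose its coherence once extreme points are filtered out can be reproduced on a small made-up sample. The sketch below is illustrative only: the vectors `x` and `y` are hypothetical and are not taken from the dataset. Ten points with almost no linear relationship gain a high Pearson coefficient as soon as a single distant point is added.

```r
# Hypothetical data: ten points with practically no linear relationship
x <- 1:10
y <- c(5, 3, 6, 2, 7, 4, 6, 3, 5, 4)
r_without <- cor(x, y, method = "pearson")

# The same data with one distant point appended
x_out <- c(x, 50)
y_out <- c(y, 60)
r_with <- cor(x_out, y_out, method = "pearson")

cat("Pearson r without the extreme point:", round(r_without, 3), "\n")
cat("Pearson r with the extreme point:   ", round(r_with, 3), "\n")
```

This is why it may be worth comparing the coefficient before and after filtering, or checking a rank-based coefficient such as Spearman's (`method = "spearman"`), before reading too much into a single value.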
It would also be interesting to check the relation between the non-numerical and numerical fields. Unfortunately, only the Working.Ion field has few enough distinct values (10) to be converted to a factor.
The association is measured with eta squared, the share of a numeric field's variance explained by the factor; by convention, values above 0.14 indicate a large effect. By this measure, the working ion is strongly associated with two of the twelve numeric fields in the set, namely average voltage and volumetric capacity, which is logical from the chemical point of view.
eta_squared <- function(numeric_col, factor_col) {
fit <- aov(numeric_col ~ factor_col)
anova <- summary(fit)[[1]]
ss_total <- sum(anova$`Sum Sq`)
ss_between <- anova$`Sum Sq`[1]
eta_sq <- ss_between / ss_total
return(eta_sq)
}
eta_sq_results <- sapply(mp %>% select_if(is.numeric), function(x) eta_squared(x, mp$Working.Ion))
eta_sq_results <- unlist(eta_sq_results)
eta_sq_df <- data.frame(
Eta_Squared = eta_sq_results
)
print(eta_sq_df)
## Eta_Squared
## Max.Delta.Volume 0.010422197
## Average.Voltage 0.172330002
## Gravimetric.Capacity 0.093020927
## Volumetric.Capacity 0.206778595
## Gravimetric.Energy 0.083795109
## Volumetric.Energy 0.087632436
## Atomic.Fraction.Charge 0.029284641
## Atomic.Fraction.Discharge 0.024075528
## Stability.Charge 0.009908295
## Stability.Discharge 0.015756057
## Steps 0.033819072
## Max.Voltage.Step 0.018526945
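Eta squared is the ratio of the between-group sum of squares to the total sum of squares. As a sanity check, it can be computed by hand for a tiny example and compared with the value derived from `aov`. The vectors `values` and `groups` below are hypothetical.

```r
# Hypothetical data: two groups of three observations each
values <- c(1, 2, 3, 4, 5, 6)
groups <- factor(c("a", "a", "a", "b", "b", "b"))

# Manual computation: SS_between / SS_total
grand_mean  <- mean(values)
group_means <- tapply(values, groups, mean)
group_sizes <- tapply(values, groups, length)
ss_between  <- sum(group_sizes * (group_means - grand_mean)^2)
ss_total    <- sum((values - grand_mean)^2)
eta_sq_manual <- ss_between / ss_total

# The same quantity taken from the one-way ANOVA table
anova_table <- summary(aov(values ~ groups))[[1]]
eta_sq_aov  <- anova_table$`Sum Sq`[1] / sum(anova_table$`Sum Sq`)

cat("manual:", eta_sq_manual, " aov:", eta_sq_aov, "\n")
```

Here the group means are 2 and 5 around a grand mean of 3.5, so the between-group sum of squares is 13.5 out of a total of 17.5, and both computations give about 0.77.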
Based on the eta squared scores computed above, the two fields that show a strong association are plotted against the Working.Ion factor. The results are presented using violin plots, with outliers marked in red.
Looking at the plot of Average Voltage, one can note that the results for alkali metals follow a certain pattern, starting with Lithium (Li), which has a dense center and long yet thin tails consisting mostly of outliers. The following ions become more and more uniformly distributed across the whole range, with thicker tails and a less clearly visible center line. Alkaline earth metals (here Magnesium and Calcium) follow a similar pattern. Additionally, one can observe that the higher the atomic number of an ion, the lower the median of the distribution. The remaining, non-alkali metals are clearly distinguishable from the alkali metals, but there is little to no similarity between them.
In the case of volumetric capacity, the trend of a decreasing center is visible only among the alkali metals, which become more densely distributed around the median and lose their upper tail. Clearly distinct from the others are the two trivalent metals (Aluminum and Yttrium), which retain a notable density even very far from the median.
mp %>%
ggplot(aes(x = Working.Ion, y = Average.Voltage)) +
geom_violin(scale = "width", fill="lightblue") +
geom_boxplot(width = 0.2, fill = "white", color = "black", outlier.color = "red", outlier.size = 2, alpha = 0.7) +
coord_cartesian(ylim = c(-5, 10))+
labs(title = "Violin Plot of Working Ion by Average Voltage",
x = "Working Ion",
y = "Avg. Voltage")+
theme_minimal()
mp %>%
ggplot(aes(x = Working.Ion, y = Volumetric.Capacity)) +
geom_violin(scale = "width", fill="lightblue") +
geom_boxplot(width = 0.2, fill = "white", color = "black", outlier.color = "red", outlier.size = 2, alpha = 0.7) +
coord_cartesian(ylim = c(-1, 4000))+
labs(title = "Violin Plot of Working Ion by Volumetric Capacity",
x = "Working Ion",
y = "Vol. Capacity")+
theme_minimal()
An additional scatter plot was prepared (volumetric capacity against average voltage) to check how the working ion used in the battery influences that relationship. As seen, some ions are concentrated densely around one point, while others show a visible trend.
static_plot <- mp %>% filter(Average.Voltage < 20) %>%
ggplot(aes(x = Average.Voltage, y=Volumetric.Capacity, color=Working.Ion)) +
geom_point() + # Scatter points
geom_smooth(method = "lm", se=F) + # Regression line
labs(title = "Correlation between Average Voltage and Volumetric Capacity per Ion",
x = "Avg. Voltage",
y = "Volumetric Capacity") +
theme_minimal()
interactive_plot <- ggplotly(static_plot)
interactive_plot
Due to the high demand for fast-charging, long-lived and high-capacity batteries to power devices ranging from phones to electric vehicles, research on new battery materials is currently one of the most important research areas. Some of the most popular trends in this area include:
Based on the data in the set, it is possible to create a machine learning model that may be used to predict some features of a new type of battery.
To present this possibility, a model is prepared that predicts the working ion based on the characteristics of the battery. The first step is to clean the data by removing the unnecessary text columns. Then the data is divided into training and test sets (with the proportion set to 0.75, meaning that three quarters of all records are put into the training set, stratified by working ion). The model is a random forest ("rf" as the method) with ten trees, tuned using 2-fold cross-validation repeated 5 times. The fitted model is then used to predict the test set, and the result is saved as a confusion matrix.
mp_for_ML <- mp %>% select(-c(Battery.Formula, Formula.Charge, Formula.Discharge))
inTraining <-
createDataPartition(
y = mp_for_ML$Working.Ion,
p = .75,
list = FALSE)
training <- mp_for_ML[ inTraining,]
testing <- mp_for_ML[-inTraining,]
ctrl <- trainControl(
method = "repeatedcv",
number = 2,
repeats = 5)
fit <- train(Working.Ion ~ .,
data = training,
method = "rf",
trControl = ctrl,
ntree = 10)
rfClasses <- predict(fit, newdata = testing)
cm <- confusionMatrix(rfClasses, testing$Working.Ion)
cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction Al Ca Cs K Li Mg Na Rb Y Zn
## Al 9 0 0 0 0 0 0 0 0 0
## Ca 1 51 0 1 2 17 0 1 7 5
## Cs 0 0 1 0 0 0 1 2 0 1
## K 0 0 2 3 2 0 0 4 0 0
## Li 4 22 0 10 573 14 52 2 5 2
## Mg 2 24 0 0 6 64 0 0 1 11
## Na 1 5 1 6 22 1 19 1 1 0
## Rb 0 0 4 6 1 0 2 2 0 0
## Y 5 1 0 0 0 0 0 0 8 1
## Zn 1 5 0 0 4 9 3 0 1 71
##
## Overall Statistics
##
## Accuracy : 0.7396
## 95% CI : (0.7124, 0.7655)
## No Information Rate : 0.5633
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5765
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: Al Class: Ca Class: Cs Class: K Class: Li Class: Mg
## Sensitivity 0.39130 0.47222 0.1250000 0.11538 0.9393 0.60952
## Specificity 1.00000 0.96513 0.9962791 0.99243 0.7653 0.95501
## Pos Pred Value 1.00000 0.60000 0.2000000 0.27273 0.8377 0.59259
## Neg Pred Value 0.98696 0.94289 0.9935065 0.97854 0.9073 0.95795
## Prevalence 0.02124 0.09972 0.0073869 0.02401 0.5633 0.09695
## Detection Rate 0.00831 0.04709 0.0009234 0.00277 0.5291 0.05910
## Detection Prevalence 0.00831 0.07849 0.0046168 0.01016 0.6316 0.09972
## Balanced Accuracy 0.69565 0.71868 0.5606395 0.55391 0.8523 0.78227
## Class: Na Class: Rb Class: Y Class: Zn
## Sensitivity 0.24675 0.166667 0.347826 0.78022
## Specificity 0.96223 0.987862 0.993396 0.97681
## Pos Pred Value 0.33333 0.133333 0.533333 0.75532
## Neg Pred Value 0.94347 0.990637 0.985955 0.97978
## Prevalence 0.07110 0.011080 0.021237 0.08403
## Detection Rate 0.01754 0.001847 0.007387 0.06556
## Detection Prevalence 0.05263 0.013850 0.013850 0.08680
## Balanced Accuracy 0.60449 0.577264 0.670611 0.87852
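The statistics reported above can be recomputed directly from the counts in the confusion matrix, which helps when reading the per-class rows. The sketch below retypes the counts from the printed matrix (rows are predictions, columns are reference classes); the variable names are chosen here for illustration.

```r
ions <- c("Al", "Ca", "Cs", "K", "Li", "Mg", "Na", "Rb", "Y", "Zn")

# Counts copied from the printed confusion matrix, one prediction per row
cm_counts <- matrix(c(
  9,  0, 0,  0,   0,  0,  0, 0, 0,  0,
  1, 51, 0,  1,   2, 17,  0, 1, 7,  5,
  0,  0, 1,  0,   0,  0,  1, 2, 0,  1,
  0,  0, 2,  3,   2,  0,  0, 4, 0,  0,
  4, 22, 0, 10, 573, 14, 52, 2, 5,  2,
  2, 24, 0,  0,   6, 64,  0, 0, 1, 11,
  1,  5, 1,  6,  22,  1, 19, 1, 1,  0,
  0,  0, 4,  6,   1,  0,  2, 2, 0,  0,
  5,  1, 0,  0,   0,  0,  0, 0, 8,  1,
  1,  5, 0,  0,   4,  9,  3, 0, 1, 71
), nrow = 10, byrow = TRUE, dimnames = list(Prediction = ions, Reference = ions))

# Accuracy: correctly classified examples over all examples
accuracy <- sum(diag(cm_counts)) / sum(cm_counts)

# Sensitivity (recall): correct predictions over all true members of a class
sensitivity <- diag(cm_counts) / colSums(cm_counts)

# Positive predictive value (precision): correct over all predictions of a class
ppv <- diag(cm_counts) / rowSums(cm_counts)

print(round(accuracy, 4))
print(round(sensitivity, 4))
print(round(ppv, 4))
```

For Lithium, for example, 573 of 610 test examples are recovered (sensitivity of about 0.94), while Potassium is recovered in only 3 of 26 cases (about 0.12), which is the imbalance effect visible in the table.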
This matrix is then normalized by the number of examples in each reference class, so that every cell shows the percentage of that class assigned to a given prediction, and is presented as a heatmap. As seen, the model performs fairly well. It tends to favor the more common categories (Lithium in particular) and does a poor job on the less numerous classes (such as Potassium and Cesium). In general, though, the main diagonal (from the lower left corner upward) is covered quite well. To achieve better results, another method may be used, or the data should be resampled beforehand to guarantee a uniform distribution of classes.
conf_matrix <- as.matrix(cm$table)
conf_matrix_percent <- sweep(conf_matrix, 2, colSums(conf_matrix), FUN = "/") * 100
conf_melt <- melt(conf_matrix_percent)
conf_melt %>% ggplot(aes(x = Prediction, y = Reference, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient(high = "navy", low = "white", name = "Percentage") +
labs(x = "Predicted Label", y = "True Label", title = "Confusion Matrix Heatmap") +
theme_minimal()
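One way to obtain the uniform class distribution suggested above is random upsampling: rows of the rarer classes are drawn with replacement until every class is as large as the most common one. The following is a minimal base R sketch under stated assumptions: the `toy` data frame and its values are hypothetical, and in a real run the resampling would normally be applied to the training set only, so that the test set keeps its original distribution.

```r
set.seed(413)

# Hypothetical, imbalanced toy data: six Li, three Na and one K record
toy <- data.frame(
  Working.Ion     = factor(c(rep("Li", 6), rep("Na", 3), "K")),
  Average.Voltage = c(3.2, 3.5, 3.9, 4.1, 3.0, 3.7, 2.8, 3.1, 2.5, 2.9)
)

# Target size: the size of the largest class
target_n <- max(table(toy$Working.Ion))

# Row indices grouped by class
idx_by_class <- split(seq_len(nrow(toy)), toy$Working.Ion)

# Draw target_n rows with replacement from every class
resampled_idx <- unlist(lapply(idx_by_class, function(idx) {
  idx[sample.int(length(idx), target_n, replace = TRUE)]
}))

balanced <- toy[resampled_idx, ]
print(table(balanced$Working.Ion))
```

Upsampling only repeats existing rows, so it does not add information about the rare ions; it merely stops the model from being rewarded for always guessing the majority class.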